from tool.args import get_general_args
from tool.util import init_wandb
from train.mlbase import MLBase
from evaluate.evaluator import Evaluator
import torch.nn.functional as F

from data.dl_getter import DATASETS, n_cls, sh, input_range
import pandas as pd
import argparse
import numpy as np
import sys

import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from tool.util import set_seed, bool_flag
from datetime import datetime
import os
from data.ds import Non_dataset
from data.ds import ood_root
from data.dl_getter import get_transform
from torch.utils.data import DataLoader
from attack import attack
import pickle
from tqdm import tqdm


@torch.no_grad()
def check_acc(model, vl_dl):
    """Print and return the top-1 accuracy of ``model`` over ``vl_dl``.

    The model is switched to eval mode and left in that mode. Each batch is
    moved to the device that holds the model's parameters; when the model
    exposes no parameters, CUDA is used.

    Args:
        model: Module called as ``model(x)`` whose output carries the class
            scores along dim 1.
        vl_dl: Iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        The fraction of correctly classified samples as a float, or ``nan``
        when ``vl_dl`` yields no samples.
    """
    model.eval()

    # Follow the model's own device instead of assuming CUDA, so the check
    # also works for models kept on the CPU or on a non-default GPU.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda")

    correct = 0
    total = 0
    for x, y in vl_dl:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)
        total += y.size(0)
        correct += (pred == y).sum().item()

    # An empty loader has no defined accuracy; report nan rather than
    # raising ZeroDivisionError.
    acc = correct / total if total else float("nan")
    print(f"acc : {acc}")
    return acc


def main(eval):
    """Dump clean and PGD-attacked validation samples plus their latents.

    Steps, all writing pickles under ``ad_sample/``:

    1. Print the clean accuracy of ``eval.model`` on ``eval.vl_dl``.
    2. Save the unmodified validation set as the "0 steps" sample file.
    3. For each PGD step count, run ``attack`` and save the returned
       samples and labels.
    4. Reload every sample file, run ``model.enc`` over it in batches and
       save the concatenated result as ``..._latent.pkl``.

    Args:
        eval: Object exposing ``model`` (with an ``enc`` method) and
            ``vl_dl`` (an iterable of ``(inputs, labels)`` batches).
    """
    model = eval.model
    dl = eval.vl_dl
    check_acc(model, dl)

    eps = 8/255.
    batch_size = 100
    attack_steps = [1, 2, 4, 8, 16, 32, 10, 20]
    out_dir = 'ad_sample'
    # NOTE: file names keep the fixed "0.03" tag even though eps is 8/255;
    # the names are left unchanged so existing consumers keep working.
    prefix = f'{out_dir}/pgd_0.03'

    # The pickles are opened with 'wb', which does not create directories.
    os.makedirs(out_dir, exist_ok=True)

    # "0 steps" == the clean validation data, stored in the same format as
    # the attacked samples so the latent pass below can treat them alike.
    xs, ys = [], []
    for x, y in tqdm(dl):
        xs.append(x.cuda())
        ys.append(y.cuda())
    xs = torch.cat(xs, dim=0)
    ys = torch.cat(ys, dim=0)
    with open(f'{prefix}_0.pkl', 'wb') as f:
        pickle.dump({'data': xs, 'label': ys}, f)

    for steps in attack_steps:
        print(f'PGD: eps : {eps} | steps : {steps}')
        attack_x, attack_y = attack(
            model=model,
            dl=dl,
            batch_size=batch_size,
            steps=steps,
            eps=eps,
            seed=1
        )
        with open(f'{prefix}_{steps}.pkl', 'wb') as f:
            pickle.dump({'data': attack_x, 'label': attack_y}, f)

    with torch.no_grad():
        # Derived from attack_steps so the two lists cannot drift apart.
        for step in [0] + attack_steps:
            # Only files written earlier in this run are unpickled here.
            with open(f'{prefix}_{step}.pkl', 'rb') as f:
                x = pickle.load(f)['data']

            # Walk the whole sample set in batches; the number of batches
            # follows the data size instead of assuming 100 x 100 samples.
            latents = []
            for start in tqdm(range(0, len(x), batch_size)):
                batch = x[start:start + batch_size].cuda()
                latents.append(model.enc(batch))
            latents = torch.cat(latents, dim=0)

            with open(f'{prefix}_{step}_latent.pkl', 'wb') as f:
                pickle.dump(latents, f)


#python adv_attack.py --wandb_entity eavnjeong --arch resnet34 --bsz 100 --bsz_vl 100 --exp_load eph/cifar10_resnet34_lin_4 --head lin --dataset cifar10 --method evaluate
if __name__ == '__main__':
    # Build everything from the shared command-line arguments.
    args = get_general_args()
    init_wandb(args)
    # Named `evaluator` (not `eval`) so the builtin is not shadowed at
    # module scope.
    evaluator = Evaluator(MLBase(args))
    main(evaluator)